from osgeo import gdal
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
import time
import os
# Read an image
def read_img(input_file):
    """
    Read a raster file into a numpy array.

    Inputs:
    input_file: path of the image file
    Outputs:
    array_data: numpy array holding the image data. GDAL returns it as
        (bands, rows, cols) for multi-band rasters and as (rows, cols) for
        single-band rasters.
    rows: height in pixels
    cols: width in pixels
    bands: number of bands
    Raises:
    RuntimeError: if GDAL cannot open input_file
    """
    in_ds = gdal.Open(input_file)
    # gdal.Open returns None instead of raising unless gdal.UseExceptions()
    # is active; fail here with a clear message rather than with an
    # AttributeError on the next line.
    if in_ds is None:
        raise RuntimeError(f"GDAL could not open raster: {input_file!r}")

    rows = in_ds.RasterYSize  # raster height
    cols = in_ds.RasterXSize  # raster width
    bands = in_ds.RasterCount  # number of bands
    # GDT_Byte = 1, GDT_UInt16 = 2, GDT_UInt32 = 4, GDT_Int32 = 5, GDT_Float32 = 6
    datatype = in_ds.GetRasterBand(1).DataType
    print("数据类型：", datatype)
    # Read the whole raster into memory, e.g. shape (4, 36786, 37239) for a
    # four-band image: bands, rows, cols.
    array_data = in_ds.ReadAsArray()
    del in_ds  # drop the reference so GDAL closes the dataset
    # Return the array together with its height, width and band count.
    return array_data, rows, cols, bands
# Write an image
def write_img(read_path, img_array):
    """
    Write an array as a GeoTIFF that reuses the source raster's georeferencing.

    The geotransform and projection are copied from the raster at read_path.
    The output file is created next to it, named after read_path with its
    extension replaced by "_unit8.tif".

    Inputs:
    read_path: path of the original raster file
    img_array: numpy array, either (bands, rows, cols) or (rows, cols)
    Raises:
    RuntimeError: if the source raster cannot be opened or the output file
        cannot be created
    """
    read_pre_dataset = gdal.Open(read_path)
    if read_pre_dataset is None:
        raise RuntimeError(f"GDAL could not open raster: {read_path!r}")
    img_transf = read_pre_dataset.GetGeoTransform()  # affine geotransform
    img_proj = read_pre_dataset.GetProjection()  # map projection
    del read_pre_dataset  # the source is only needed for its georeferencing
    print("1readshape:", img_array.shape,img_array.dtype.name)

    # Map the numpy dtype to the GDAL pixel type by exact name. A substring
    # test such as "'int16' in name" also matches 'uint16' and would store
    # signed data as unsigned. Anything not listed falls back to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
    }
    datatype = gdal_types.get(img_array.dtype.name, gdal.GDT_Float32)

    if len(img_array.shape) == 3:
        img_bands, im_height, im_width = img_array.shape
    else:
        img_bands, (im_height, im_width) = 1, img_array.shape

    # splitext copes with extensions of any length (".tif", ".tiff", none),
    # unlike slicing off the last four characters.
    filename = os.path.splitext(read_path)[0] + '_unit8' + ".tif"
    driver = gdal.GetDriverByName("GTiff")  # GeoTIFF driver
    # Note: Create takes the width first and then the height, the reverse of
    # the array's (rows, cols) order.
    # https://vimsky.com/examples/detail/python-method-gdal.GetDriverByName.html
    dataset = driver.Create(filename, im_width, im_height, img_bands, datatype)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create raster: {filename!r}")
    dataset.SetGeoTransform(img_transf)  # write the affine geotransform
    dataset.SetProjection(img_proj)  # write the projection

    # Write the pixel data. Band.WriteArray only accepts 2-D arrays, so a
    # (1, rows, cols) array must be written through its single 2-D slice.
    if len(img_array.shape) == 2:
        dataset.GetRasterBand(1).WriteArray(img_array)
    else:
        for i in range(img_bands):
            dataset.GetRasterBand(i + 1).WriteArray(img_array[i])

    # Flush and release the dataset so the file is complete on disk on return.
    dataset.FlushCache()
    del dataset
def compress(origin_16,low_per_raw=0.01,high_per_raw=0.96):
    """
    Stretch a raster to 8 bits with a percentile cut and write the result.

    Each band is clipped between a low and a high percentile, then scaled
    linearly to 0..255. Both percentiles are taken over the non-zero pixels
    only, so that zero-valued black borders do not skew the cut points.
    The result is written by write_img next to the input file.

    Input:
    origin_16: path of the input (16-bit) image
    low_per_raw: fraction of the non-zero pixels clipped at the dark end
        (0.01 = 1%)
    high_per_raw: fraction of the non-zero pixels at or below the bright
        cut point (0.96 = 96%, i.e. the brightest 4% are clipped)
    Output:
    None; the 8-bit image is written to disk by write_img.
    """
    # Load the image, e.g. shape (4, 36786, 37239): bands, rows, cols.
    array_data, rows, cols, bands = read_img(origin_16)
    print("1shape:", array_data.shape)

    # GDAL returns a 2-D (rows, cols) array for single-band rasters; add the
    # band axis so the per-band loop below works for them as well.
    if array_data.ndim == 2:
        array_data = array_data[np.newaxis, :, :]

    # The dtype here controls the output bit depth.
    compress_data = np.zeros((bands,rows, cols),dtype="uint8")

    for i in range(bands):
        data_band = array_data[i]

        # Share of zero pixels (the black border) in this band. A boolean
        # mask is used so no full-size integer copy of the band is built.
        num0 = np.sum(data_band == 0)
        kk = num0 / (rows * cols)
        print(f'0的占比为{kk}')

        # Shift both percentiles so they refer to the non-zero pixels only;
        # otherwise the result differs from ArcGIS.
        # low:  (x - kk) / (1 - kk) = low_per_raw       -> x = low_per_raw + kk - low_per_raw * kk
        # high: (1 - x) / (1 - kk) = 1 - high_per_raw   -> x = 1 - (1 - high_per_raw) * (1 - kk)
        low_per = (low_per_raw + kk - low_per_raw * kk) * 100
        high_per = 100 - (1 - high_per_raw) * (1 - kk) * 100
        # Guard against floating-point drift just outside [0, 100], which
        # np.percentile rejects.
        low_per = min(max(low_per, 0.0), 100.0)
        high_per = min(max(high_per, 0.0), 100.0)
        print("baifen:", low_per, high_per)

        cutmin = np.percentile(data_band, low_per)
        cutmax = np.percentile(data_band, high_per)
        print("duandian:",cutmin,cutmax)

        # A constant band (for example all zeros) has no range to stretch.
        # Leave it at 0 rather than dividing by zero.
        if cutmax <= cutmin:
            print("band", i + 1, "has no value range; written as zeros")
            continue

        # Clip in floating point. Assigning the float cut points into the
        # integer source array would truncate them, which can push scaled
        # values below 0 (wrapping around in uint8) or just short of 255.
        clipped = np.clip(data_band, cutmin, cutmax)
        # Scale to 0..255.
        scaled = np.around((clipped - cutmin) * 255 / (cutmax - cutmin))
        compress_data[i, :, :] = scaled.astype(np.uint8)

    print("maxANDmin：",np.max(compress_data),np.min(compress_data))
    # Write the stretched image to disk. This call can be dropped when the
    # array is consumed directly, e.g. for neural-network prediction.
    write_img(origin_16, compress_data)
    # return compress_data
if __name__ == '__main__':
    # Gather the input rasters from the workspace directory.
    # NOTE(review): `listdir` is neither defined nor imported in the visible
    # part of this file. It is presumably a helper that appends file paths to
    # the list passed as its second argument - confirm it exists before
    # running this script.
    workspace_dir = r"G:\02工作空间"
    raster_paths = []
    listdir(workspace_dir, raster_paths)
    print(len(raster_paths))

    # Convert every raster in turn and report the time each one took.
    for raster_path in raster_paths:
        print(raster_path)
        t_begin = time.time()
        compress(raster_path)
        print("time_cost:", time.time() - t_begin)